package easy

import (
	"errors"
	"net"
	"strings"

	"gitee.com/haodreams/libs/config"
	"gitee.com/haodreams/libs/namelist"
	"gitee.com/haodreams/libs/routine"
)

// SafeListener wraps a net.Listener and restricts which client IP addresses
// may connect (限制ip访问 = "restrict IP access"). Filtering is performed in
// Accept.
type SafeListener struct {
	net.Listener
	nl   *namelist.Namelist // name list of permitted client hosts (a host in the list is accepted by Accept; an empty list accepts everyone)
	// logf, when non-nil, is called with a message for every rejected client.
	logf func(v ...interface{})
}

// NewSafeListener creates a SafeListener that reports rejected clients
// through logf; pass nil to disable that logging.
//
// Neither the embedded Listener nor the name list is set here: the caller is
// expected to assign the Listener field and call SetNameList before Accept
// is used.
func NewSafeListener(logf func(v ...interface{})) *SafeListener {
	return &SafeListener{logf: logf}
}

// Accept waits for and returns the next connection from the embedded Listener
// whose client is permitted. It keeps accepting for as long as the loop
// observes routine.IsRunning() == true.
//
// A connection is handed to the caller when no name list is configured (nil
// or empty), when the client is a loopback address (127.0.0.0/8 or ::1), or
// when the client host is contained in the name list. Any other connection is
// reported through logf (if set), closed, and skipped.
//
// Errors from the embedded Listener are returned unchanged. When the loop
// observes routine.IsRunning() == false, Accept returns net.ErrClosed instead
// of a nil (or already closed) connection paired with a nil error, so that
// servers terminate their accept loop cleanly.
func (m *SafeListener) Accept() (conn net.Conn, err error) {
	for routine.IsRunning() {
		conn, err = m.Listener.Accept()
		if err != nil {
			return nil, err
		}
		// No (or an empty) name list means no filtering at all.
		if m.nl == nil || m.nl.IsEmpty() {
			return conn, nil
		}

		// Extract the client IP. net.SplitHostPort, unlike cutting at the last
		// ':', also removes the brackets from IPv6 addresses like "[::1]:80",
		// so IPv6 hosts can match plain entries in the name list.
		addr := conn.RemoteAddr().String()
		host, _, splitErr := net.SplitHostPort(addr)
		if splitErr == nil {
			// Local clients are always allowed.
			if ip := net.ParseIP(host); ip != nil && ip.IsLoopback() {
				return conn, nil
			}
			if m.nl.In(host) {
				return conn, nil
			}
		} else {
			// The address is not in host:port form, so the client cannot be
			// matched against the list: reject it (and close it below, instead
			// of leaking the connection).
			host = addr
		}

		if m.logf != nil {
			m.logf("禁止访问:", host)
		}
		// Best effort: the connection is being rejected, a close error is irrelevant.
		_ = conn.Close()
	}
	return nil, net.ErrClosed
}

// AcceptTCP behaves like Accept but returns the connection as *net.TCPConn.
//
// If the accepted connection is not a *net.TCPConn (for example because the
// embedded Listener is not a TCP listener), that connection is closed, so it
// does not leak, and an error is returned with a nil conn.
func (m *SafeListener) AcceptTCP() (conn *net.TCPConn, err error) {
	c, err := m.Accept()
	if err != nil {
		return nil, err
	}
	conn, ok := c.(*net.TCPConn)
	if !ok {
		// c may be a nil interface if Accept returned (nil, nil); guard the Close.
		if c != nil {
			_ = c.Close()
		}
		return nil, errors.New("类型错误")
	}
	return conn, nil
}

// SetNameList installs the list of client hosts that Accept lets through.
// The list is stored by reference; later changes to it are seen by Accept.
func (m *SafeListener) SetNameList(list *namelist.Namelist) {
	m.nl = list
}

// GetNameList reports the name list currently consulted by Accept, i.e. the
// value last passed to SetNameList (nil if it was never set).
func (m *SafeListener) GetNameList() *namelist.Namelist {
	current := m.nl
	return current
}

// LoadAuthWhiteList loads the authentication white list from the
// "http.whitelist.auth" configuration key and guarantees that 127.0.0.1 is
// part of it.
func LoadAuthWhiteList() *namelist.Namelist {
	const localhost = "127.0.0.1"

	list := LoadNameListFromKey("http.whitelist.auth")
	if list.In(localhost) {
		return list
	}
	list.Set(localhost)
	return list
}

// LoadAccessWhiteList loads the access white list from the
// "http.whitelist.access" configuration key.
func LoadAccessWhiteList() *namelist.Namelist {
	const configKey = "http.whitelist.access"
	list := LoadNameListFromKey(configKey)
	return list
}

// LoadNameListFromKey builds a name list from the comma-separated value that
// config.String returns for key.
//
// Leading and trailing whitespace is trimmed from every entry, so a value
// written as "10.0.0.1, 10.0.0.2" stores "10.0.0.2" rather than " 10.0.0.2"
// (which would never match a client address). Empty entries are skipped. An
// unset or empty key yields a list with no entries added.
func LoadNameListFromKey(key string) *namelist.Namelist {
	nl := namelist.NewNamelist()
	// strings.Split("", ",") yields [""], which the empty check below skips,
	// so no separate test for an empty config value is needed.
	for _, host := range strings.Split(config.String(key), ",") {
		host = strings.TrimSpace(host)
		if host == "" {
			continue
		}
		nl.Set(host)
	}
	return nl
}
